# -*- coding: utf-8 -*-

import numpy as np
from scipy.special import erf #pjli 2022/3/2

# Silence numpy divide-by-zero and invalid-value (NaN-producing) floating-point
# warnings, e.g. from log(0) or inf*0 in the model formulas below.
# NOTE: np.seterr changes numpy's error state for the whole process, not just
# this module. Overflow warnings are NOT silenced (that would need over='ignore').
np.seterr(divide='ignore',invalid='ignore')

# pjli 20220323 note -- meaning of the `case` argument of the model functions
# (BC, VG, KO, HP, AT, GD, SG, CA):
#   case None: theta(h) -- volumetric water content as a function of head h
#   case 1:    f(h)     -- -dtheta/dh, the water-content density in head h
#   case 2:    g(r)     -- the same density in pore radius r, with h = 0.149/r
#   case 3:    Kr(h)    -- relative hydraulic conductivity as a function of h

def fin(x,para,model):
    """Return the integrand of the numerical Mualem conductivity integral.

    In the log-head variable ``x = ln(h) - alp`` each model's density (the
    x-derivative of its effective saturation, up to a constant factor) is
    weighted by ``exp(-x)``, i.e. by ``1/h`` up to the constant ``exp(-alp)``.
    Constant factors cancel in the ratio of integrals formed by IN_pre.

    Parameters
    ----------
    x : float or ndarray
        Log-head coordinate(s) ``ln(h) - alp``.
    para : sequence of 4 floats
        ``(thr, ths, alp, npar)``; only the width parameter ``npar`` is used.
    model : {'HP', 'AT', 'GD', 'SG', 'CA'}
        Model code.

    Returns
    -------
    float or ndarray
        Integrand value(s), same shape as ``x``.

    Raises
    ------
    ValueError
        If ``model`` is not one of the supported codes.
    """
    _thr, _ths, _alp, npar = para
    if model=='HP':
        # density ~ sech^2(x/npar)
        y=np.exp(-x)*(1/np.cosh(x/npar))**2
    elif model=='AT':
        # density ~ 1/(npar^2 + x^2)
        y=np.exp(-x)/(npar**2+x**2)
    elif model=='GD':
        # density ~ sech(x/npar), written as cosh/(sinh^2 + 1)
        y=np.exp(-x)*np.cosh(x/npar)/((np.sinh(x/npar))**2+1)
    elif model=='SG':
        # density ~ e^(x/npar) / (e^(x/npar) + 1)^2
        y=np.exp(-x+x/npar)/(np.exp(x/npar)+1)**2
    elif model=='CA':
        # density ~ ((x/npar)^2 + 1)^(-3/2)
        y=np.exp(-x)/((x/npar)**2+1)**1.5
    else:
        raise ValueError(
            f"unknown model {model!r}; expected one of 'HP', 'AT', 'GD', 'SG', 'CA'")
    return y

def IN_LO(up,dn,para,model,ndr=10000):
    """Integrate ``fin(x, para, model)`` from ``x = dn`` to ``x = up``.

    Uses the composite trapezoidal rule on a uniform grid.

    Parameters
    ----------
    up, dn : float
        Upper and lower integration limits. ``up < dn`` is allowed and gives
        the negated integral; IN_pre calls it that way (for h < hmax), and
        the sign cancels in its squared ratio.
    para : sequence of 4 floats
        ``(thr, ths, alp, npar)``, passed through to ``fin``.
    model : str
        Model code passed through to ``fin``.
    ndr : int, optional
        Number of trapezoids (default 10000).

    Returns
    -------
    numpy.float64
        Trapezoidal estimate of the integral.

    Raises
    ------
    ValueError
        If ``ndr`` is less than 1.
    """
    n = int(ndr)
    if n < 1:
        raise ValueError("ndr must be at least 1")
    # A single grid of n+1 nodes: each interior node is shared by two
    # trapezoids, so the integrand is evaluated n+1 times rather than 2n.
    grid = np.linspace(dn, up, n + 1)
    f = fin(grid, para, model)
    dr = (up - dn) / n
    return np.sum(f[:-1] + f[1:]) * dr / 2

def Hmin(para,model):
    """Return the head ``hmin`` below which IN_pre sets Kr = 1.

    ``ln(hmin) - alp`` is also the limit of the normalising integral in
    IN_pre. The values are fixed per model and do not depend on ``para``.

    Parameters
    ----------
    para : sequence of 4 floats
        ``(thr, ths, alp, npar)``; accepted for interface compatibility but
        currently unused.
    model : {'HP', 'AT', 'GD', 'SG', 'CA'}
        Model code.

    Returns
    -------
    float
        The cutoff head.

    Raises
    ------
    ValueError
        If ``model`` is not recognised.
    """
    # Unused parameter-dependent alternatives (kept for reference). Each one
    # solves Se(hmin) = Smax for that model's effective-saturation curve:
    #   HP, Smax=0.98: exp(alp + npar*arctanh((0.5-Smax)*2))
    #   AT, Smax=0.91: exp(alp + npar*tan((0.5-Smax)*pi))
    #   GD, Smax=0.98: exp(alp + npar*arcsinh(tan((0.5-Smax)*pi)))
    #   SG, Smax=0.98: exp(alp - npar*log(1/(1-Smax) - 1))
    #   CA, Smax=0.91: exp(alp + ((1-2*Smax)**2/(4*Smax*(Smax-1)))*npar)
    # NOTE(review): the CA form appears to lack the square root needed to
    # invert Se = 0.5 - 0.5*u/sqrt(1+u**2); confirm before enabling it.
    cutoffs = {'HP': 1.6, 'AT': 1.3, 'GD': 1.6, 'SG': 1.6, 'CA': 1.3}
    try:
        return cutoffs[model]
    except KeyError:
        raise ValueError(
            f"unknown model {model!r}; expected one of {sorted(cutoffs)}") from None


def IN_pre(h,Se,para,model,hmax=1e4):
    """Relative hydraulic conductivity Kr(h) from a numerical Mualem integral.

    For ``hmin <= h < hmax``::

        Kr = Se**0.5 * (I(h) / I(hmin))**2,
        I(a) = integral of fin(x) for x from ln(hmax)-alp to ln(a)-alp

    Here ``hmin = Hmin(para, model)``. Kr = 1 for ``h < hmin``. Kr = 0 for
    ``h >= hmax``, where the truncated integral is empty. Without that
    clamp, the integral would change sign and, once squared, make Kr grow
    again with h.

    NOTE(review): at h = hmin the ratio is 1, so Kr = Se(hmin)**0.5 < 1,
    i.e. Kr jumps there from the value 1 used below hmin.

    Parameters
    ----------
    h : sequence of float
        Heads at which to evaluate Kr.
    Se : sequence of float
        Effective saturation at each ``h`` (same length as ``h``).
    para : sequence of 4 floats
        ``(thr, ths, alp, npar)``.
    model : str
        Model code accepted by ``fin`` and ``Hmin``.
    hmax : float, optional
        Upper truncation head of the integral (default 1e4).

    Returns
    -------
    list of float
        Kr for each entry of ``h``.
    """
    _thr, _ths, alp, _npar = para
    h_arr = np.asarray(h, dtype='float64')
    se_arr = np.asarray(Se, dtype='float64')
    hmin = Hmin(para, model)
    dn = np.log(hmax) - alp
    # The normalising integral is the same for every h: compute it once.
    y2 = IN_LO(np.log(hmin) - alp, dn, para, model)
    kr_est = []
    for hi, sei in zip(h_arr, se_arr):
        if hi < hmin:
            kr = 1.0
        elif hi >= hmax:
            kr = 0.0
        else:
            y1 = IN_LO(np.log(hi) - alp, dn, para, model)
            kr = sei**0.5 * (y1 / y2)**2
        kr_est.append(kr)
    return kr_est
        

def BC(h,para,case=None):
    """Brooks-Corey retention model and derived quantities (vectorised).

    ::

        theta(h) = thr + (ths-thr) * (psi_s/h)**npar   for h > psi_s
                 = ths                                  otherwise

    Parameters
    ----------
    h : array_like
        Head values. For ``case == 2`` these are pore radii r (h = 0.149/r).
    para : sequence of 4 floats
        ``(thr, ths, psi_s, npar)``: residual and saturated water content,
        air-entry head ``psi_s`` and exponent ``npar``.
    case : {None, 1, 2, 3}, optional
        None -> theta(h).
        1 -> f(h) = -dtheta/dh (0 for h <= psi_s).
        2 -> g(r), the same density in pore radius (0 for r >= 0.149/psi_s).
        3 -> Kr(h) = (psi_s/h)**(3*npar + 2), and 1 for h <= psi_s.

    Returns
    -------
    ndarray for case None/1/2; list of float for case 3.
    """
    thr,ths,psi_s,npar = para
    x = np.asarray(h, dtype='float64')
    if case==2:
        # x holds pore radii; r < 0.149/psi_s is equivalent to h > psi_s.
        out = np.zeros(x.shape)
        m = x < (0.149/psi_s)
        out[m] = (ths-thr)*(npar/x[m])*(x[m]*psi_s/0.149)**npar
        return out
    # Unsaturated part of the curve. Masked indexing avoids evaluating the
    # power law where it is not used (e.g. h = 0).
    m = x > psi_s
    if case==1:
        out = np.zeros(x.shape)
        out[m] = (ths-thr)*(npar/x[m])*(psi_s/x[m])**npar
        return out
    if case==3:
        out = np.ones(x.shape)
        out[m] = (psi_s/x[m])**(3*npar+2)
        return out.tolist()
    out = np.full(x.shape, ths, dtype='float64')
    out[m] = thr+(ths-thr)*(psi_s/x[m])**npar
    return out

def VG(h,para,case=None):
    """van Genuchten retention model and derived quantities.

    ::

        theta(h) = thr + (ths-thr) / (1 + (alp*|h|)**npar)**(1 - 1/npar)

    Parameters
    ----------
    h : ndarray or float
        Head values. For ``case == 2`` these are pore radii r (h = 0.149/r).
    para : sequence of 4 floats
        ``(thr, ths, alp, npar)``.
    case : {None, 1, 2, 3}, optional
        None -> theta(h).
        1 -> f(h) = -dtheta/dh.
        2 -> g(r) = f(0.149/r) * 0.149/r**2.
        3 -> Mualem Kr = Se**0.5 * (1 - (1 - Se**(1/m))**m)**2, m = 1 - 1/npar.

    Returns
    -------
    ndarray or float
    """
    thr, ths, alp, npar = para
    span = ths - thr
    theta = thr + span/(1 + (alp*abs(h))**npar)**(1 - 1/npar)

    if case == 1:
        ah = alp*h
        return span*(npar - 1)*alp*ah**(npar - 1)*(1 + ah**npar)**(1/npar - 2)
    if case == 2:
        ar = alp*0.149/h  # alp * head corresponding to radius h
        return span*(npar - 1)*alp*ar**(npar - 1)*(1 + ar**npar)**(1/npar - 2)*(0.149/h**2)
    if case == 3:
        m = 1 - 1/npar
        se = (theta - thr)/(ths - thr)
        return se**0.5*(1 - (1 - se**(1/m))**m)**2
    return theta

def KO(h,para,case=None):
    """Lognormal retention model (code 'KO') and derived quantities.

    ::

        theta(h) = thr + (ths-thr) * 0.5 * erfc(ln(|h|/alp) / (sqrt(2)*npar))

    ``alp`` is the head at which Se = 0.5 and ``npar`` is the standard
    deviation of ln(h).

    Parameters
    ----------
    h : ndarray or float
        Head values. For ``case == 2`` these are pore radii r (h = 0.149/r).
    para : sequence of 4 floats
        ``(thr, ths, alp, npar)``.
    case : {None, 1, 2, 3}, optional
        None -> theta(h).
        1 -> f(h) = -dtheta/dh.
        2 -> g(r).
        3 -> Kr = Se**0.5 * (erfc((npar**2 + ln(h/alp)) / (sqrt(2)*npar)) / 2)**2.

    Returns
    -------
    ndarray or float
    """
    # Use scipy's erfc directly. Computing 1 - erf(x) loses all precision for
    # large x (it is exactly 0 once erf(x) rounds to 1.0), which would zero
    # theta - thr and Kr at large h.
    from scipy.special import erfc
    thr,ths,alp,npar = para
    theta_est = (thr+(ths-thr)*0.5* erfc(np.log(abs(h)/alp)/(np.sqrt(2)*npar)))
    if case==1:
        theta_est = (ths-thr)/(np.sqrt(2*np.pi)*npar*h)*np.exp(-(np.log(h/alp))**2/(2*npar**2))
    elif case==2:
        theta_est = (ths-thr)/(np.sqrt(2*np.pi)*npar*h)*np.exp(-(np.log(0.149/(alp*h)))**2/(2*npar**2))
    elif case==3:
        Se=(theta_est-thr)/(ths-thr)
        Vmid=(npar**2+np.log(h/alp))/np.sqrt(2)/npar
        Q=(erfc(Vmid)/2)**2
        theta_est = Se**0.5*Q
    return theta_est

def HP(h,para,case=None):
    """Hyperbolic-tangent retention model (code 'HP') and derived quantities.

    ::

        theta(h) = thr + (ths-thr) * 0.5 * (1 - tanh((ln|h| - alp)/npar))

    Parameters
    ----------
    h : ndarray or float
        Head values. For ``case == 2`` these are pore radii r (h = 0.149/r).
    para : sequence of 4 floats
        ``(thr, ths, alp, npar)``.
    case : {None, 1, 2, 3}, optional
        None -> theta(h).
        1 -> f(h) = -dtheta/dh = (ths-thr)/(2*h*npar) * sech^2((ln h - alp)/npar).
        2 -> g(r) = f(0.149/r) * 0.149/r**2.
        3 -> Kr(h) via IN_pre (returns a list).

    Returns
    -------
    ndarray or float (case None/1/2); list (case 3)
    """
    thr,ths,alp,npar = para
    theta_est = thr+(ths-thr)*0.5*(1-np.tanh((np.log(abs(h))-alp)/npar))
    if case==1:
        # d/dx tanh(x/npar) = sech^2(x/npar)/npar, so there is no factor of
        # pi; this density integrates to ths - thr over h.
        theta_est = (ths-thr)/(2*h*npar)/(np.cosh((np.log(abs(h))-alp)/npar))**2
    elif case==2:
        theta_est = (ths-thr)/(2*h*npar)/(np.cosh((np.log(0.149/np.exp(alp)/abs(h)))/npar))**2
    elif case==3:
        Se=(theta_est-thr)/(ths-thr)
        model='HP'
        theta_est = IN_pre(h,Se,para,model)
    return theta_est

def AT(h,para,case=None):
    """Arctangent retention model (code 'AT') and derived quantities.

    ::

        theta(h) = thr + (ths-thr) * (0.5 - arctan((ln|h| - alp)/npar)/pi)

    Parameters
    ----------
    h : ndarray or float
        Head values. For ``case == 2`` these are pore radii r (h = 0.149/r).
    para : sequence of 4 floats
        ``(thr, ths, alp, npar)``.
    case : {None, 1, 2, 3}, optional
        None -> theta(h).
        1 -> f(h) = (ths-thr)/(pi*h) * npar/(npar**2 + (ln|h| - alp)**2).
        2 -> g(r), the same density in pore radius.
        3 -> Kr(h) via IN_pre (returns a list).

    Returns
    -------
    ndarray or float (case None/1/2); list (case 3)
    """
    thr, ths, alp, npar = para
    span = ths - thr
    u = np.log(abs(h)) - alp
    theta = thr + span*(-np.arctan(u/npar)/np.pi + 0.5)

    if case == 1:
        return span/(np.pi*h)*npar/(npar**2 + u**2)
    if case == 2:
        v = np.log(0.149/np.exp(alp)/abs(h))
        return span/(np.pi*h)*npar/(npar**2 + v**2)
    if case == 3:
        return IN_pre(h, (theta - thr)/(ths - thr), para, 'AT')
    return theta

def GD(h,para,case=None):
    """Gudermannian retention model (code 'GD') and derived quantities.

    ::

        theta(h) = thr + (ths-thr)/2 - (ths-thr)/pi * gd((ln h - alp)/npar)

    Here ``gd(z) = arctan(sinh(z))``.

    Parameters
    ----------
    h : ndarray or float
        Head values. For ``case == 2`` these are pore radii r (h = 0.149/r).
    para : sequence of 4 floats
        ``(thr, ths, alp, npar)``.
    case : {None, 1, 2, 3}, optional
        None -> theta(h).
        1 -> f(h) = (ths-thr)/(pi*h*npar) * sech((ln h - alp)/npar).
        2 -> g(r), the same density in pore radius.
        3 -> Kr(h) via IN_pre (returns a list).

    Returns
    -------
    ndarray or float (case None/1/2); list (case 3)
    """
    thr, ths, alp, npar = para
    span = ths - thr
    z = (np.log(h) - alp)/npar
    theta = thr + span/2 - span/np.pi*np.arctan(np.sinh(z))

    def density(w):
        # sech(w) expressed as cosh(w)/(sinh(w)**2 + 1)
        return span/(np.pi*h*npar)*np.cosh(w)/(np.sinh(w)**2 + 1)

    if case == 1:
        return density(z)
    if case == 2:
        return density((np.log(0.149/np.exp(alp)/abs(h)))/npar)
    if case == 3:
        return IN_pre(h, (theta - thr)/(ths - thr), para, 'GD')
    return theta

def SG(h,para,case=None):
    """Logistic-in-log-head retention model (code 'SG') and derived quantities.

    ::

        theta(h) = ths - (ths-thr) / (1 + exp(-(ln h - alp)/npar))

    Parameters
    ----------
    h : ndarray or float
        Head values. For ``case == 2`` these are pore radii r (h = 0.149/r).
    para : sequence of 4 floats
        ``(thr, ths, alp, npar)``.
    case : {None, 1, 2, 3}, optional
        None -> theta(h).
        1 -> f(h) = (ths-thr)/(h*npar) * e^z/(e^z + 1)**2, z = (ln h - alp)/npar.
        2 -> g(r), the same density in pore radius.
        3 -> Kr(h) via IN_pre (returns a list).

    Returns
    -------
    ndarray or float (case None/1/2); list (case 3)
    """
    thr, ths, alp, npar = para
    span = ths - thr
    z = (np.log(h) - alp)/npar
    theta = ths - span*(1 + np.exp(-z))**(-1)

    def density(w):
        ew = np.exp(w)
        return span/(h*npar)*ew/(ew + 1)**2

    if case == 1:
        return density(z)
    if case == 2:
        return density((np.log(0.149/np.exp(alp)/abs(h)))/npar)
    if case == 3:
        return IN_pre(h, (theta - thr)/(ths - thr), para, 'SG')
    return theta

def CA(h,para,case=None):
    """Algebraic-sigmoid retention model (code 'CA') and derived quantities.

    ::

        u = (ln h - alp)/npar
        theta(h) = thr + (ths-thr)/2 - (ths-thr)/2 * u/sqrt(1 + u**2)

    Points where this evaluates to NaN are reported as ``ths``. Examples are
    h = 0 (u = -inf gives -inf*0) or h < 0 (log of a negative number).

    Parameters
    ----------
    h : ndarray or float
        Head values. For ``case == 2`` these are pore radii r (h = 0.149/r).
    para : sequence of 4 floats
        ``(thr, ths, alp, npar)``.
    case : {None, 1, 2, 3}, optional
        None -> theta(h).
        1 -> f(h) = (ths-thr)/(2*h*npar) * (u**2 + 1)**(-3/2).
        2 -> g(r), the same density in pore radius.
        3 -> Kr(h) via IN_pre (returns a list).

    Returns
    -------
    ndarray (case None/1/2); list (case 3)
    """
    thr,ths,alp,npar = para
    mid = (np.log(h)-alp)/npar   # intermediate variable: scaled log head
    theta_est = thr + (ths-thr)/2 - (ths-thr)/2 * mid * (1+mid**2)**(-1/2)
    # np.where instead of in-place masked assignment, so scalar h also works
    # (numpy scalars do not support item assignment).
    theta_est = np.where(np.isnan(theta_est), ths, theta_est)
    if case==1:
        theta_est = (ths-thr)/(2*h*npar)/(((np.log(h)-alp)/npar)**2+1)**1.5
    elif case==2:
        theta_est = (ths-thr)/(2*h*npar)/(((np.log(0.149/np.exp(alp)/abs(h)))/npar)**2+1)**1.5
    elif case==3:
        Se=(theta_est-thr)/(ths-thr)
        model='CA'
        theta_est = IN_pre(h,Se,para,model)
    return theta_est
    
    